﻿/********************************************************
 *  ██████╗  ██████╗████████╗██╗
 * ██╔════╝ ██╔════╝╚══██╔══╝██║
 * ██║  ███╗██║        ██║   ██║
 * ██║   ██║██║        ██║   ██║
 * ╚██████╔╝╚██████╗   ██║   ███████╗
 *  ╚═════╝  ╚═════╝   ╚═╝   ╚══════╝
 * Geophysical Computational Tools & Library (GCTL)
 *
 * Copyright (c) 2023  Yi Zhang (yizhang-geo@zju.edu.cn)
 *
 * GCTL is distributed under a dual licensing scheme. You can redistribute 
 * it and/or modify it under the terms of the GNU Lesser General Public 
 * License as published by the Free Software Foundation, either version 2 
 * of the License, or (at your option) any later version. You should have 
 * received a copy of the GNU Lesser General Public License along with this 
 * program. If not, see <http://www.gnu.org/licenses/>.
 * 
 * If the terms and conditions of the LGPL v.2. would prevent you from using 
 * the GCTL, please consider the option to obtain a commercial license for a 
 * fee. These licenses are offered by the GCTL's original author. As a rule, 
 * licenses are provided "as-is", unlimited in time for a one time fee. Please 
 * send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget 
 * to include some description of your company and the realm of its activities. 
 * Also add information on how to contact you by electronic and paper mail.
 ******************************************************/

#ifndef _GCTL_NATIVE_IO_H
#define _GCTL_NATIVE_IO_H

#include "../core.h"
#include "../utility.h"

namespace gctl
{
    /**
     * @brief      从GCTL定义的二进制文件（后缀为.ar）导入数据到数组
     * 
     * @warning    此函数在导入元素前会先清空原来的数组，再重新初始化数组
     *
     * @param[in]  filename   文件名(不包含文件后缀)
     * @param      out_arr    输出的数组
     * @param[in]  info_ptr   存储头信息的字符串指针，若不为空则保存头信息到字符串
     *
     * @tparam     T          模板类型
     */
    template <typename T>
    void read_binary2array(std::string filename, array<T> &out_arr, 
        std::string *info_ptr = nullptr)
    {
        std::ifstream infile;
        open_infile(infile, filename, ".ar", std::ios::in|std::ios::binary);

        int info_size;
        std::string head_info, ele_name, type_name = typeid(T).name();

        infile.read((char*)&info_size, sizeof(int));
        if (info_size != 0)
        {
            head_info.resize(info_size);
            infile.read((char*)head_info.c_str(), info_size);
        }

        if (info_ptr != nullptr)
        {
            *info_ptr = head_info;
        }

        infile.read((char*)&info_size, sizeof(int));
        ele_name.resize(info_size);
        infile.read((char*)ele_name.c_str(), info_size);

        // 首先读入一个整形并与元素的大小进行比较
        int ele_size;
        infile.read((char*)&ele_size, sizeof(int));
        if (ele_size != sizeof(T) || ele_name != type_name)
        {
            throw runtime_error("Incompatible element size or name. From gctl::read_binary2array(...)");
        }

        // 读入元素个数
        int in_size;
        infile.read((char*)&in_size, sizeof(int));
        if (in_size <= 0)
        {
            throw runtime_error("No element found. From gctl::read_binary2array(...)");
        }

        out_arr.resize(in_size);
        infile.read((char*)out_arr.get(), sizeof(T)*in_size);
        infile.close();
        return;
    }

    /**
     * @brief      保存数组数据到GCTL定义的二进制文件（后缀为.ar）
     * 
     * 二进制文件格式：
     * 后缀名称：.ar
     * 格式：1. 整形值，等于头信息的长度
     * 格式：2. 字符串，头信息（所保存数据的必要说明）
     * 格式：3. 整形值，等于数组元素的typeid名称的长度
     * 格式：4. 字符串，数组元素的typeid名称（验证文件所保存数据是否为正确的数据类型）
     * 格式：5. 整形值，数组元素的sizeof长度（验证文件所保存数据是否为正确的数据类型）
     * 格式：6. 整形值，输出的数组元素的个数
     * 格式：7. 输出的数组元素的列表
     *
     * @param[in]  filename   文件名(不包含文件后缀)
     * @param[in]  in_arr     输入的数组
     * @param[in]  head_info  需要保存的头信息
     * @param[in]  st         数组保存的起始位置，默认为数组开始位置
     * @param[in]  ed         数组保存的结束位置，默认为数组结束位置
     * 
     *  @tparam     T          模板类型
     */
    template <typename T>
    void save_array2binary(std::string filename, const array<T> &in_arr, 
        std::string head_info = "No Head Info", int st = 0, int ed = 0)
    {
        if (in_arr.empty())
        {
            throw length_error("The operating array is empty. From gctl::save_array2binary(...)");
        }

        if (st < 0)
        {
            throw out_of_range("Invalid index. From gctl::save_array2binary(...)");
        }

        if (ed == 0) ed = in_arr.size();

        if (ed < 0 || ed > in_arr.size() || st >= ed)
        {
            throw out_of_range("Invalid index. From gctl::save_array2binary(...)");
        }

        std::ofstream outfile;
        open_outfile(outfile, filename, ".ar", std::ios::out|std::ios::binary);

        int info_size = head_info.size();
        outfile.write((char*)&info_size, sizeof(int));
        outfile.write((char*)head_info.c_str(), info_size);

        const std::type_info &tinfo = typeid(T);
        std::string ele_name = tinfo.name();
        info_size = ele_name.size();
        outfile.write((char*)&info_size, sizeof(int));
        outfile.write((char*)ele_name.c_str(), info_size);

        int ele_size = sizeof(T);
        outfile.write((char*)&ele_size, sizeof(int));

        int out_size = ed - st;
        outfile.write((char*)&out_size, sizeof(int));
        outfile.write((char*)in_arr.get(st), sizeof(T)*out_size);
        outfile.close();
        return;
    }

    /**
     * @brief      从GCTL定义的二进制文件（后缀为.mat）导入数据到数组
     * 
     * @warning    此函数在导入元素前会先清空原来的数组，再重新初始化数组
     *
     * @param[in]  filename   文件名(不包含文件后缀)
     * @param      out_arr    输出的数组
     * @param[in]  info_ptr   存储头信息的字符串指针，若不为空则保存头信息到字符串
     *
     * @tparam     T          模板类型
     */
    template <typename T>
    void read_binary2matrix(std::string filename, matrix<T> &out_arr, 
        std::string *info_ptr = nullptr)
    {
        std::ifstream infile;
        open_infile(infile, filename, ".mat", std::ios::in|std::ios::binary);

        int info_size;
        std::string head_info, ele_name, type_name = typeid(T).name();

        infile.read((char*)&info_size, sizeof(int));
        if (info_size != 0)
        {
            head_info.resize(info_size);
            infile.read((char*)head_info.c_str(), info_size);
        }

        if (info_ptr != nullptr)
        {
            *info_ptr = head_info;
        }

        infile.read((char*)&info_size, sizeof(int));
        ele_name.resize(info_size);
        infile.read((char*)ele_name.c_str(), info_size);

        // 首先读入一个整形并与元素的大小进行比较
        int ele_size;
        infile.read((char*)&ele_size, sizeof(int));
        if (ele_size != sizeof(T) || ele_name != type_name)
        {
            throw runtime_error("Incompatible element size. From gctl::read_binary2matrix(...)");
        }

        // 读入元素个数
        int in_rowsize, in_colsize;
        infile.read((char*)&in_rowsize, sizeof(int));
        infile.read((char*)&in_colsize, sizeof(int));
        if (in_rowsize <= 0 || in_colsize <= 0)
        {
            throw runtime_error("No element found. From gctl::read_binary2matrix(...)");
        }

        out_arr.resize(in_rowsize, in_colsize);
        for (int i = 0; i < in_rowsize; i++)
        {
            infile.read((char*)out_arr.get(i), sizeof(T)*in_colsize);
        }

        infile.close();
        return;
    }

    /**
     * @brief      保存数组数据到GCTL定义的二进制文件（后缀为.mat）
     * 
     * 二进制文件格式：
     * 后缀名称：.mat
     * 格式：1. 整形值，等于头信息的长度
     * 格式：2. 字符串，头信息（所保存数据的必要说明）
     * 格式：3. 整形值，等于数组元素的typeid名称的长度
     * 格式：4. 字符串，数组元素的typeid名称（验证文件所保存数据是否为正确的数据类型）
     * 格式：5. 整形值，数组元素的sizeof长度（验证文件所保存数据是否为正确的数据类型）
     * 格式：6. 整形值，输出的数组元素的行数
     * 格式：7. 整形值，输出的数组元素的列数
     * 格式：8. 输出的数组元素的列表
     *
     * @param[in]  filename   文件名(不包含文件后缀)
     * @param[in]  in_arr     输入的数组
     * @param[in]  head_info  需要保存的头信息
     * @param[in]  r_st       开始的行索引，默认为数组开始位置
     * @param[in]  r_ed       结束的行索引，默认为数组结束位置
     * @param[in]  c_st       开始的列索引，默认为数组开始位置
     * @param[in]  c_ed       开始的列索引，默认为数组开始位置
     *
     * @tparam     T          模板类型
     */
    template <typename T>
    void save_matrix2binary(std::string filename, const matrix<T> &in_arr, 
        std::string head_info = "No Head Info", int r_st = 0, int r_ed = 0, 
        int c_st = 0, int c_ed = 0)
    {
        if (in_arr.empty())
        {
            throw runtime_error("The input array is empty. From gctl::save_matrix2binary(...)");
        }

        if (r_st < 0 || c_st < 0 || r_ed < 0 || c_ed < 0 || 
            r_ed > in_arr.row_size() || c_ed > in_arr.col_size())
        {
            throw invalid_argument("Invalid index. From gctl::save_matrix2binary(...)");
        }

        if (r_ed == 0) r_ed = in_arr.row_size();
        if (c_ed == 0) c_ed = in_arr.col_size();

        std::ofstream outfile;
        open_outfile(outfile, filename, ".mat", std::ios::out|std::ios::binary);

        int info_size = head_info.size();
        int ele_size = sizeof(T);
        int out_rowsize = r_ed - r_st;
        int out_colsize = c_ed - c_st;
        outfile.write((char*)&info_size, sizeof(int));
        outfile.write((char*)head_info.c_str(), info_size);

        const std::type_info &tinfo = typeid(T);
        std::string ele_name = tinfo.name();
        info_size = ele_name.size();
        outfile.write((char*)&info_size, sizeof(int));
        outfile.write((char*)ele_name.c_str(), info_size);

        outfile.write((char*)&ele_size, sizeof(int));
        outfile.write((char*)&out_rowsize, sizeof(int));
        outfile.write((char*)&out_colsize, sizeof(int));
        for (int i = r_st; i < r_ed; i++)
        {
            for (size_t j = c_st; j < c_ed; j++)
            {
                outfile.write((char*)&in_arr[i][j], sizeof(T));
            }
            //outfile.write((char*)in_arr.get(i, c_st), sizeof(T)*(c_ed-c_st));
        }
        outfile.close();
        return;
    }

    /**
     * @brief      从GCTL定义的二进制文件（后缀为.spm）导入数据到稀疏矩阵
     * 
     * @warning    此函数在导入元素前会先清空原来的矩阵，再重新初始化数组
     *
     * @param[in]  filename   文件名(不包含文件后缀)
     * @param      out_mat    输出的稀疏矩阵
     * @param[in]  info_ptr   存储头信息的字符串指针，若不为空则保存头信息到字符串
     *
     * @tparam     T          模板类型
     */
    template <typename T>
    void read_binary2spmat(std::string filename, spmat<T> &out_mat, 
        std::string *info_ptr = nullptr)
    {
        std::ifstream infile;
        open_infile(infile, filename, ".spm", std::ios::in|std::ios::binary);

        int info_size;
        std::string head_info, ele_name, type_name = typeid(T).name();

        infile.read((char*)&info_size, sizeof(int));
        if (info_size != 0)
        {
            head_info.resize(info_size);
            infile.read((char*)head_info.c_str(), info_size);
        }

        if (info_ptr != nullptr)
        {
            *info_ptr = head_info;
        }

        infile.read((char*)&info_size, sizeof(int));
        ele_name.resize(info_size);
        infile.read((char*)ele_name.c_str(), info_size);

        // 首先读入一个整形并与元素的大小进行比较
        int ele_size;
        infile.read((char*)&ele_size, sizeof(int));
        if (ele_size != sizeof(T) || ele_name != type_name)
        {
            throw runtime_error("Incompatible element size or name. From gctl::read_binary2spmat(...)");
        }
        // 读入元素个数
        int ele_num, in_rnum, in_cnum;
        T in_zeroval;
        infile.read((char*)&in_rnum, sizeof(int));
        infile.read((char*)&in_cnum, sizeof(int));
        infile.read((char*)&ele_num, sizeof(int));
        infile.read((char*)&in_zeroval, sizeof(T));

        out_mat.clear();
        out_mat.malloc(in_rnum, in_cnum, in_zeroval);

        int tmp_rid, tmp_cid;
        T tmp_val;

        for (int i = 0; i < ele_num; i++)
        {
            infile.read((char*)&tmp_rid, sizeof(int));
            infile.read((char*)&tmp_cid, sizeof(int));
            infile.read((char*)&tmp_val, sizeof(T));

            out_mat.insert(tmp_rid, tmp_cid, tmp_val);
        }
        infile.close();
        return;
    }

    /**
     * @brief      保存稀疏矩阵数据到GCTL定义的二进制文件（后缀为.spm）
     * 
     * 二进制文件格式：
     * 后缀名称：.ar
     * 格式：1. 整形值，等于头信息的长度
     * 格式：2. 字符串，头信息（所保存数据的必要说明）
     * 格式：3. 整形值，等于数组元素的typeid名称的长度
     * 格式：4. 字符串，数组元素的typeid名称（验证文件所保存数据是否为正确的数据类型）
     * 格式：5. 整形值，数组元素的sizeof长度（验证文件所保存数据是否为正确的数据类型）
     * 格式：6. 整形值，输出的数组元素的个数
     * 格式：7. 输出的数组元素的列表
     *
     * @param[in]  filename   文件名(不包含文件后缀)
     * @param[in]  in_arr     输入的数组
     * @param[in]  head_info  需要保存的头信息
     * @param[in]  st         数组保存的起始位置，默认为数组开始位置
     * @param[in]  ed         数组保存的结束位置，默认为数组结束位置
     * 
     *  @tparam     T          模板类型
     */
    template <typename T>
    void save_spmat2binary(std::string filename, const spmat<T> &in_mat, 
        std::string head_info = "No Head Info")
    {
        std::ofstream outfile;
        open_outfile(outfile, filename, ".spm", std::ios::out|std::ios::binary);

        int info_size = head_info.size();
        outfile.write((char*)&info_size, sizeof(int));
        outfile.write((char*)head_info.c_str(), info_size);

        const std::type_info &tinfo = typeid(T);
        std::string ele_name = tinfo.name();
        info_size = ele_name.size();
        outfile.write((char*)&info_size, sizeof(int));
        outfile.write((char*)ele_name.c_str(), info_size);

        int ele_size = sizeof(T);
        outfile.write((char*)&ele_size, sizeof(int));

        int r_num = in_mat.row_size();
        int c_num = in_mat.col_size();
        int n_num = in_mat.ele_size();
        T zero_val = in_mat.zero_value();
        outfile.write((char*)&r_num, sizeof(int));
        outfile.write((char*)&c_num, sizeof(int));
        outfile.write((char*)&n_num, sizeof(int));
        outfile.write((char*)&zero_val, sizeof(T));

        array<size_t> rows, cols;
        array<T> vals;
        in_mat.export_coo(rows, cols, vals);

        for (size_t i = 0; i < rows.size(); i++)
        {
            outfile.write((char*)&rows[i], sizeof(size_t));
            outfile.write((char*)&cols[i], sizeof(size_t));
            outfile.write((char*)&vals[i], sizeof(T));
        }
        outfile.close();
        return;
    }
}

#endif // _GCTL_NATIVE_IO_H